import os
import json
import torch
from random import randint
from tqdm import tqdm
import argparse
from diffusers import StableDiffusionPipeline, DiffusionPipeline, CogView4Pipeline, FluxPipeline, StableDiffusion3Pipeline

def load_pipeline(model_name):
    """
    Build the diffusers pipeline registered under ``model_name`` and move it to CUDA.

    Each supported key maps to a local checkpoint directory under
    ``target_models/`` plus a pipeline class / dtype (and optional weight
    variant) used for ``from_pretrained``.

    Args:
        model_name: One of "sd35", "sdxl", "sd-turbo", "cogview4", "flux", "sd14".

    Returns:
        The loaded pipeline after ``.to("cuda")``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported keys.
    """
    # Local checkpoint directory for every supported model key.
    checkpoint_dirs = {
        "sd35": "target_models/stable-diffusion-3.5-medium",
        "sdxl": "target_models/stable-diffusion-xl-base-1.0",
        "sd-turbo": "target_models/sd-turbo",
        "cogview4": "target_models/CogView4-6B",
        "flux": "target_models/FLUX.1-dev",
        "sd14": "target_models/stable-diffusion-v1-4",
    }

    checkpoint = checkpoint_dirs.get(model_name)
    if checkpoint is None:
        raise ValueError(f"Unsupported model: {model_name}")

    # Pipeline class and extra from_pretrained() keyword arguments per model.
    # Built only after validation so an unknown name fails with ValueError.
    loaders = {
        "sd35": (StableDiffusion3Pipeline, {"torch_dtype": torch.bfloat16}),
        "sdxl": (DiffusionPipeline, {"torch_dtype": torch.float16, "variant": "fp16"}),
        "sd-turbo": (StableDiffusionPipeline, {"torch_dtype": torch.float16, "variant": "fp16"}),
        "cogview4": (CogView4Pipeline, {"torch_dtype": torch.bfloat16}),
        "flux": (FluxPipeline, {"torch_dtype": torch.bfloat16}),
        "sd14": (StableDiffusionPipeline, {"torch_dtype": torch.float16}),
    }
    pipeline_cls, load_kwargs = loaders[model_name]

    pipeline = pipeline_cls.from_pretrained(checkpoint, **load_kwargs)
    return pipeline.to("cuda")

def main():
    # 设置命令行参数解析器
    parser = argparse.ArgumentParser(description="Generate images with different models")
    
    # 定义输入输出参数（使用 --input 和 --output）
    parser.add_argument('--input_json', required=True)
    parser.add_argument('--output_dir', required=True)
    parser.add_argument('--output_json', required=True)
    parser.add_argument('--model', required=True, choices=['sd35', 'sdxl', 'sd-turbo', 'cogview4', 'flux', 'sd14'])
    parser.add_argument('--gpu', required=True)
    parser.add_argument('--poison_ratio', required=True, type=float, help="Poison ratio to control prompt_times")
    parser.add_argument('--batch_size', required=True)
    
    # 解析命令行参数
    args = parser.parse_args()

    torch.cuda.set_device(f"cuda:{args.gpu}")

    # 根据 poison_ratio 计算 prompt_times
    prompt_times = int(args.poison_ratio * 4)
    if prompt_times <= 0:
        print("Poison ratio is too low, it must be greater than 0.")
        return
    
    # 加载 poison_list 数据
    with open(args.input_json, "r") as f:
        poison_list = json.load(f)

    # 创建存储图片和 JSON 数据的目录
    os.makedirs(f"{args.output_dir}_{args.poison_ratio}%_{args.model}", exist_ok=True)
    print(f"mkdir {args.output_dir}_{args.poison_ratio}%_{args.model}")
    
    # 准备 JSON 数据
    poison_data = {}

    # 设置批处理大小（可以根据显存调整）
    batch_size = int(args.batch_size)

    total_pairs = len(poison_list)
    total_images_needed = total_pairs * prompt_times  # 总共需要生成的图像对数量

    # 生成一个长度为 total_images_needed 的列表，每个元素表示要处理的某个 "pair" 的内容
    image_pairs = []

    # 将每个 poison_list 中的条目转换成 (original, chosen, reject) 三元组
    for entry in poison_list:
        for _ in range(prompt_times):  # 每个条目生成prompt_times个图像
            image_pairs.append((
                entry["original_prompt"],
                entry["chosen_prompt"],
                entry["reject_prompt"]
            ))

    # 加载对应的模型
    try:
        pipe = load_pipeline(args.model)
    except ValueError as e:
        print(e)
        return

    # 计数器，用来为每个生成的图像分配唯一的编号
    image_counter = 0

    # 分批次处理生成图像
    for i in tqdm(range(0, total_images_needed, batch_size), desc="Generating images in batches"):
        batch_entries = image_pairs[i:i + batch_size]
        
        # 收集每个 batch 的 prompts
        batch_original_prompts = [entry[0] for entry in batch_entries]
        batch_chosen_prompts = [entry[1] for entry in batch_entries]
        batch_reject_prompts = [entry[2] for entry in batch_entries]
        
        seeds = [randint(0, 2**32 - 1) for _ in range(len(batch_entries))]
        generators = [torch.manual_seed(seed) for seed in seeds]

        try:
            print(f"Processing batch {i // batch_size + 1}")
            
            chosen_images = pipe(batch_chosen_prompts, num_inference_steps=50, guidance_scale=7.5, generator=generators, height=512, width=512).images
            reject_images = pipe(batch_reject_prompts, num_inference_steps=50, guidance_scale=7.5, generator=generators, height=512, width=512).images

            for k, (chosen_img, reject_img) in enumerate(zip(chosen_images, reject_images)):
                # 使用全局计数器 `image_counter` 生成唯一的图像编号
                chosen_image_filename = f"chosen_{image_counter}.jpg"
                reject_image_filename = f"reject_{image_counter}.jpg"
                
                chosen_image_path = f"{args.output_dir}_{args.poison_ratio}%_{args.model}/{chosen_image_filename}"
                reject_image_path = f"{args.output_dir}_{args.poison_ratio}%_{args.model}/{reject_image_filename}"
                
                # 保存图像
                chosen_img.save(chosen_image_path)
                reject_img.save(reject_image_path)
                
                key = f"poisoned_pair_{image_counter}"
                poison_data[key] = {
                    "prompt": batch_entries[k][0],
                    "chosen": chosen_image_path,
                    "reject": reject_image_path
                }

                # 递增计数器
                image_counter += 1

        except Exception as e:
            print(f"Error in batch {i // batch_size + 1}: {e}")

    # 将生成的数据保存为 JSON 文件
    with open(f"{args.output_dir}_{args.poison_ratio}%_{args.model}.json", "w") as f:
        json.dump(poison_data, f, indent=4)

    print(f"生成完成，结果保存在 {args.output_dir}_{args.poison_ratio}%_{args.model} 文件夹中")

if __name__ == "__main__":
    # Script entry point; main() returns None, so the process exits with status 0.
    raise SystemExit(main())


# Example: run in the background on GPU 0 with poison_ratio=2, logging to logs/gen_old_2%.log
# nohup python gen_batch_all.py --input_json "prompt_json/prompt_old.json" --output_dir "poison_data/old" --output_json "poison_data/old" --model "sd35" --gpu 0 --poison_ratio 2 --batch_size 16 > logs/gen_old_2%.log 2>&1 &